import argparse

def parse_args_batch_train(argv=None):
    """Build and parse the command-line options for batch (non-episodic) training.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, exactly as the
            original zero-argument call did; passing a list lets callers and
            tests parse options without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below
        (``--weight-decay`` is exposed as ``weight_decay``).
    """
    parser = argparse.ArgumentParser(description='batch_train')

    # Data / dataset options.
    parser.add_argument('--dataset', default='miniImageNet', help='training base model')
    parser.add_argument('--data_path', default='./miniImagenet', help='train data path')
    parser.add_argument('--image_size', default=224, type=int, help='image size')
    parser.add_argument('--base_class', default=64, type=int, help='total number of classes in base class')
    parser.add_argument('--batch_size', default=16, type=int, help='total number of batch size in base class')

    # Backbone options.
    parser.add_argument('--feature_size', default=512, type=int, help='feature_size')
    parser.add_argument('--backbone', default='ResNet10', help='backbone type')
    # Per-block lists: nargs/type make a command-line override produce a list
    # of ints (e.g. `--list_of_stride 1 2 2 2`) instead of a single raw string.
    parser.add_argument('--list_of_out_dims', default=[64, 128, 256, 512], nargs='+', type=int, help='every block output')
    parser.add_argument('--list_of_stride', default=[1, 2, 2, 2], nargs='+', type=int, help='every block conv stride')
    parser.add_argument('--list_of_dilated_rate', default=[1, 1, 1, 1], nargs='+', type=int, help='dilated conv')

    # Training options.
    parser.add_argument('--method', default='Linear_Classifier', help='Linear_Classifier/ProtoNet')
    # NOTE(review): kept as a string ('True'); any non-empty value, including
    # 'False', is truthy — confirm how callers interpret it.
    parser.add_argument('--train_aug', default='True', help='perform data augmentation or not during training ')
    parser.add_argument('--save_freq', default=100, type=int, help='Save frequency')
    parser.add_argument('--save_dir', default='./logs', help='Save dir')
    parser.add_argument('--epoch', default=400, type=int, help='total batch train epoch')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--weight-decay', default=5e-4, type=float, help='weight decay (default: 5e-4)')
    parser.add_argument('--seed', default=1111, type=int, help='random seed')
    parser.add_argument('--model_path', default='./log/best_model.tar', help='model_path')

    # Loss weights and momentum coefficients.
    parser.add_argument('--lamba_alignment_loss_cross', default=1.0, type=float, help='lamba_alignment_loss_cross')
    parser.add_argument('--lamba_loss_proto', default=1.0, type=float, help='lamba_loss_proto')
    parser.add_argument('--alignment_loss_low', default=1.0, type=float, help='alignment_loss_low')
    parser.add_argument('--alignment_loss_high', default=1.0, type=float, help='alignment_loss_high')
    parser.add_argument('--m_low', type=float, default=0.998, help='m_low of moment')
    parser.add_argument('--m_high', type=float, default=0.998, help='m_high of moment')
    parser.add_argument('--cross_alignment_loss', default=1.0, type=float, help='cross_alignment_loss')
    parser.add_argument('--lamba_alignment_loss', default=1.0, type=float, help='lamba_alignment_loss')

    return parser.parse_args(argv)

def parse_args_eposide_train(argv=None):
    """Build and parse the command-line options for episodic (meta-learning) training.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, exactly as the
            original zero-argument call did; passing a list lets callers and
            tests parse options without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below
        (``--weight-decay`` is exposed as ``weight_decay``).
    """
    parser = argparse.ArgumentParser(description='eposide_train')

    # Data / dataset options.
    parser.add_argument('--dataset', default='miniImageNet', help='training base model')
    parser.add_argument('--data_path', default='./miniImagenet', help='train data path')
    parser.add_argument('--image_size', default=224, type=int, help='image size')
    parser.add_argument('--base_class', default=64, type=int, help='total number of classes in base class')

    # Backbone options.
    parser.add_argument('--backbone', default='ResNet10', help='backbone type')
    # Per-block lists: nargs/type make a command-line override produce a list
    # of ints (e.g. `--list_of_stride 1 2 2 2`) instead of a single raw string.
    parser.add_argument('--list_of_out_dims', default=[64, 128, 256, 512], nargs='+', type=int, help='every block output')
    parser.add_argument('--list_of_stride', default=[1, 2, 2, 2], nargs='+', type=int, help='every block conv stride')
    parser.add_argument('--list_of_dilated_rate', default=[1, 1, 1, 1], nargs='+', type=int, help='dilated conv')

    # Training options.
    parser.add_argument('--method', default='ProtoNet', help='Linear_Classifier/ProtoNet')
    # NOTE(review): kept as a string ('True'); any non-empty value, including
    # 'False', is truthy — confirm how callers interpret it.
    parser.add_argument('--train_aug', default='True', help='perform data augmentation or not during training ')
    parser.add_argument('--save_freq', default=100, type=int, help='Save frequency')
    parser.add_argument('--save_dir', default='./logs', help='Save dir')
    parser.add_argument('--epoch', default=400, type=int, help='total batch train epoch')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--weight-decay', default=5e-4, type=float, help='weight decay (default: 5e-4)')

    # Episode (task) shape.
    parser.add_argument('--n_way', default=5, type=int, help='class num to classify for every task')
    parser.add_argument('--n_support', default=5, type=int, help='number of labeled data in each class, same as n_support')
    parser.add_argument('--n_query', default=15, type=int, help='number of test data in each class, same as n_query')
    parser.add_argument('--n_eposide', default=100, type=int, help='total task every epoch')  # episodes per training epoch
    parser.add_argument('--model_path', default='./log/best_model.tar', help='model_path')
    parser.add_argument('--seed', default=1111, type=int, help='random seed')
    parser.add_argument('--feature_size', default=512, type=int, help='feature_size')

    # Method-specific hyper-parameters.
    parser.add_argument('--lamba_kl_loss', default=1.0, type=float, help='lamba_kl_loss')
    parser.add_argument('--m', type=float, default=0.998, help='epsilon of moment')
    parser.add_argument('--lamba1', type=float, default=1.0, help='lamba_cross')
    parser.add_argument('--topK', default=100, type=int, help='topK')
    parser.add_argument('--topK1', default=200, type=int, help='topK1')
    parser.add_argument('--topK2', default=100, type=int, help='topK2')
    parser.add_argument('--n_clusters', default=3, type=int, help='n_clusters')
    parser.add_argument('--crop_size', default=96, type=int, help='crop_size')
    parser.add_argument('--crop_num', default=3, type=int, help='crop_num')
    parser.add_argument('--min_scale_crops', default=0.14, type=float, help='min_scale_crops')
    parser.add_argument('--max_scale_crops', default=1.0, type=float, help='max_scale_crops')
    parser.add_argument('--patch_size', default=32, type=int, help='patch_size')
    parser.add_argument('--lamba_variance_loss', default=1.0, type=float, help='variance_loss')
    parser.add_argument('--variance_loss_beta', default=1e-04, type=float, help='variance_loss_beta')
    # NOTE(review): the default contains a trailing space ('./clip/ '); kept
    # unchanged in case existing caches live there — confirm it is intended.
    parser.add_argument('--clip_model_path', default='./clip/ ', help='clip_model_path')

    # Target-domain data and evaluation.
    parser.add_argument('--current_data_path', default='./datasets/ISIC', help='ISIC_data_path')
    parser.add_argument('--current_class', default=7, type=int, help='total number of classes in ISIC')
    parser.add_argument('--test_n_eposide', default=600, type=int, help='total task every epoch')  # number of evaluation episodes

    # Loss weights and momentum coefficients.
    parser.add_argument('--kl_beta', default=1.0, type=float, help='kl_beta')
    parser.add_argument('--ju_alpha', default=1.0, type=float, help='ju_alpha')
    parser.add_argument('--lamba_alignment_loss', default=1.0, type=float, help='lamba_alignment_loss')
    parser.add_argument('--lamba_alignment_loss_cross', default=1.0, type=float, help='lamba_alignment_loss_cross')
    parser.add_argument('--lamba_loss_proto', default=1.0, type=float, help='lamba_loss_proto')
    parser.add_argument('--alignment_loss_low', default=1.0, type=float, help='alignment_loss_low')
    parser.add_argument('--alignment_loss_high', default=1.0, type=float, help='alignment_loss_high')
    parser.add_argument('--m_low', type=float, default=0.998, help='m_low of moment')
    parser.add_argument('--m_high', type=float, default=0.998, help='m_high of moment')
    parser.add_argument('--cross_alignment_loss', default=1.0, type=float, help='cross_alignment_loss')
    parser.add_argument('--lamba_loss_ce_fuse', default=1.0, type=float, help='lamba_loss_ce_fuse')
    parser.add_argument('--lamba_reconstruct_loss', default=1.0, type=float, help='lamba_reconstruct_loss')

    return parser.parse_args(argv)

def parse_args_test(argv=None):
    """Build and parse the command-line options for few-shot evaluation.

    Both ``--name`` and single-dash ``-name`` option styles are kept as
    originally declared so existing command lines keep working.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, exactly as the
            original zero-argument call did; passing a list lets callers and
            tests parse options without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser(description='test')

    # Evaluation datasets.
    parser.add_argument('--EuroSAT_data_path', default='./datasets/EuroSAT/2750', help='EuroSAT_data_path')
    parser.add_argument('--mini_test_data_path', default='./test', help='mini_test_data_path')
    parser.add_argument('--mini_test_class', default=20, type=int, help='total number of classes in the mini test set')
    parser.add_argument('--image_size', default=224, type=int, help='image size')
    parser.add_argument('--EuroSAT_class', default=10, type=int, help='total number of classes in EuroSAT')

    # Backbone options.
    parser.add_argument('--feature_size', default=512, type=int, help='feature_size')
    # Per-block lists: nargs/type make a command-line override produce a list
    # of ints (e.g. `--list_of_stride 1 2 2 2`) instead of a single raw string.
    parser.add_argument('--list_of_out_dims', default=[64, 128, 256, 512], nargs='+', type=int, help='every block output')
    parser.add_argument('--list_of_stride', default=[1, 2, 2, 2], nargs='+', type=int, help='every block conv stride')
    parser.add_argument('--list_of_dilated_rate', default=[1, 1, 1, 1], nargs='+', type=int, help='dilated conv')
    parser.add_argument('--model_path', default='./log/best_model.tar', help='model_path')

    # Episode (task) shape.
    parser.add_argument('--n_way', default=5, type=int, help='class num to classify for every task')
    parser.add_argument('--n_support', default=5, type=int, help='number of labeled data in each class, same as n_support')
    parser.add_argument('--n_query', default=15, type=int, help='number of test data in each class, same as n_query')
    parser.add_argument('--test_n_eposide', default=600, type=int, help='total task every epoch')  # number of evaluation episodes
    parser.add_argument('--seed', default=1111, type=int, help='random seed')

    parser.add_argument('--current_data_path', default='./datasets/ISIC', help='ISIC_data_path')
    parser.add_argument('--current_class', default=7, type=int, help='total number of classes in ISIC')

    parser.add_argument('--DS_N', default=750, type=int, help='number of gen data in each class')

    # Fine-tuning / pseudo-labelling hyper-parameters.
    parser.add_argument('--n_aug_support_samples', default=5, type=int, help='total number of n_aug_support_samples')
    parser.add_argument('--ft_lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--pseudo_n_query', default=5, type=int, help='number of test data in each class, same as pseudo n_query')
    parser.add_argument('--pseudo_total_epoch', default=100, type=int, help='pseudo_total_epoch')
    parser.add_argument('--topK', default=10, type=int, help='topK')
    parser.add_argument('--topK1', default=200, type=int, help='topK1')
    parser.add_argument('--topK2', default=100, type=int, help='topK2')
    parser.add_argument('--crop_num', default=3, type=int, help='crop_num')
    parser.add_argument('--kmeans_niter', default=10, type=int, help='kmeans_niter')
    parser.add_argument('--lamba1', type=float, default=1.0, help='lamba_cross')
    parser.add_argument('--crop_size', default=96, type=int, help='crop_size')
    parser.add_argument('--patch_size', default=112, type=int, help='patch_size')

    parser.add_argument('--tr_N', default=7, type=int, help='tr_N')
    parser.add_argument('--tr_K', default=10, type=int, help='tr_K')

    # Pre-computed base-class statistics and external model paths.
    parser.add_argument('--base_means_path', default='./dc_feature/miniImageNet_base_means.npy', help='base_means_path')
    parser.add_argument('--base_cov_path', default='./dc_feature/miniImageNet_base_cov.npy', help='base_cov_path')
    # NOTE(review): the default contains a trailing space ('./clip/ '); kept
    # unchanged in case existing caches live there — confirm it is intended.
    parser.add_argument('--clip_model_path', default='./clip/ ', help='clip_model_path')

    # DINO backbone options.
    parser.add_argument('--dino_model_path', default='dino_deitsmall8_pretrain.pth', help='dino_model_path')
    parser.add_argument('--dino_arch', default='vit_small', help='dino_arch')
    parser.add_argument('--dino_patch_size', default=16, type=int, help='dino_patch_size')

    parser.add_argument('--dataset_type', default='datasetname', help='dataset_type')

    parser.add_argument('--sample_number', default=50, type=int, help='sample_number')
    parser.add_argument('--ju_alpha', default=1.0, type=float, help='ju_alpha')
    parser.add_argument('--gama', default=1.0, type=float, help='gama')
    parser.add_argument('--beta', default=0.5, type=float, help='beta')

    parser.add_argument('--data_path', default='./miniImagenet', help='train data path')

    # Semi-supervised split sizes.
    parser.add_argument('--data_name', default='ISIC', help='data_name')
    parser.add_argument('--model_name', default='dino', help='model_name')
    parser.add_argument('--num_class', default=7, type=int, help='num_class')
    parser.add_argument('--sup_shot', default=5, type=int, help='sup_shot')
    parser.add_argument('--unsup_shot', default=10, type=int, help='unsup_shot')
    parser.add_argument('--test_shot', default=100, type=int, help='test_shot')

    # DeepEMD options (single-dash names kept for command-line compatibility).
    parser.add_argument('--deepemd', default='fcn', help='deepemd')
    parser.add_argument('--deepemd_model_dir', default='./deepEMD_model/max_acc.pth', help='deepemd_model_dir')
    parser.add_argument('-feature_pyramid', type=str, default=None)
    # solver
    parser.add_argument('-solver', type=str, default='opencv', choices=['opencv'])
    # SFC
    parser.add_argument('-sfc_lr', type=float, default=100)
    parser.add_argument('-sfc_wd', type=float, default=0, help='weight decay for SFC weight')
    # NOTE(review): declared as float although it reads like a step count;
    # a command-line value becomes e.g. 100.0 — confirm callers accept that.
    parser.add_argument('-sfc_update_step', type=float, default=100)
    parser.add_argument('-sfc_bs', type=int, default=4)
    parser.add_argument('-norm', type=str, default='center', choices=['center'])
    parser.add_argument('-metric', type=str, default='cosine', choices=['cosine'])
    parser.add_argument('-temperature', type=float, default=12.5)
    parser.add_argument('--gpu', default='0,1')
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('-way', type=int, default=5)
    parser.add_argument('--config', default='./configs/test_few_shot.yaml')

    # NOTE(review): all three share the same default; confirm that
    # --image_path and --save_path should default to 'source_data.npy'.
    parser.add_argument('--save_name', type=str, default='source_data.npy')
    parser.add_argument('--image_path', type=str, default='source_data.npy')
    parser.add_argument('--save_path', type=str, default='source_data.npy')

    return parser.parse_args(argv)










